{-#LANGUAGE TupleSections #-}
module AssertGen.SimpleAssertGen where

import Prelude hiding (pi)

import Control.Monad
import Data.List
import qualified Data.Map as M
import Data.Maybe (catMaybes, mapMaybe)
import qualified Data.Set as S

import AssertGen.AST

-- Right-nested arrow type: arr [a, b] r == a -> (b -> r).
arr :: [Term] -> Term -> Term
arr premises conclusion = foldr TArrow conclusion premises

-- Right-nested dependent product over the given (name, type) binders.
pi :: [(String, Term)] -> Term -> Term
pi binders body = foldr (uncurry TPi) body binders

-- Right-nested lambda abstraction over the given (name, type) binders.
lam :: [(String, Term)] -> Term -> Term
lam binders body = foldr (uncurry TLam) body binders

-- Left-nested application: app f [x, y] == (f x) y.
-- Uses the strict left fold so no chain of unevaluated TApp thunks is
-- built for long argument lists.
app :: Term -> [Term] -> Term
app base args = foldl' TApp base args

-- Applies the constant with the given name to a list of arguments.
capp :: String -> [Term] -> Term
capp name = app (TConst name)

-- Shorthand for a variable term.
v :: String -> Term
v name = TVar name

-- Shorthand for a constant term.
c :: String -> Term
c name = TConst name

-- Wraps a term as (data t); builds the same term as data'.
qc :: Term -> Term
qc = capp "data" . pure

-- Wraps a term as (data t).
data' :: Term -> Term
data' = capp "data" . pure

-- Wraps a term as (data+ t).
datap :: Term -> Term
datap = capp "data+" . pure

-- Wraps a term as (hyp t).
hyp :: Term -> Term
hyp = capp "hyp" . pure

-- Builds the four-argument judgement (ca* t1 t2 t3 t4).
castar :: Term -> Term -> Term -> Term -> Term
castar t1 t2 t3 t4 = capp "ca*" [t1, t2, t3, t4]

-- Wraps a term as (ax t).
ax :: Term -> Term
ax = capp "ax" . pure

-- The variable FP.
fp :: Term
fp = TVar "FP"

-- The variable D.
d :: Term
d = TVar "D"

-- The variable E.
e :: Term
e = TVar "E"

-- The infinite sequence 1, 2, 3, ...; zipped against lists to number
-- generated variable names starting at 1.
defSeq :: [Int]
defSeq = enumFrom 1

-- Builds (conc V t) with the fixed variable V.
concV :: Term -> Term
concV t = capp "conc" [TVar "V", t]

-- Wraps a term as (conc' t).
conc' :: Term -> Term
conc' = capp "conc'" . pure

-- Wraps a term as (conc* t).
concs :: Term -> Term
concs = capp "conc*" . pure

-- A type family: its name plus one name per family argument. Each argument
-- name a is used to refer to a per-argument equality family (eq-a, @eq-a).
data Fam = Fam
  { famName :: String
  , famArgs :: [String]
  }
  deriving (Show)

-- One inhabitant (constructor) of a family.
data Inhabitant = Inh
  { inhFam :: Fam               -- family the inhabitant belongs to
  , inhName :: String           -- constructor name, used as F/name
  , inhExistentials :: [String] -- names bound with hole types in the left case
  , inhSubs :: [Term]           -- sub-premise types
  , inhConc :: [Term]           -- conclusion index terms, zipped with famArgs
  }
  deriving (Show)

-- Generates all declarations for the given families and writes them to the
-- seven output files. Every file is truncated first; then, for each family,
-- one group of declarations per file is rendered with showSig' and appended
-- (groups are matched to files positionally), followed by a newline.
make :: FilePath -> FilePath -> FilePath -> FilePath -> FilePath -> FilePath -> FilePath
     -> [(Fam, [Inhabitant])] -> IO ()
make dataFile assertFile admitFile elimFile normDeclFile normCasesFile abbrevFile fams = do
  mapM_ (`writeFile` "") fileList
  mapM_ (uncurry make') fams
  where
    fileList = [dataFile, assertFile, admitFile, elimFile, normDeclFile, normCasesFile, abbrevFile]
    make' fam is = do
      let unembName = famName fam ++ "-unemb"
      -- The i-th rendered group is appended to the i-th file of fileList.
      zipWithM_ appendFile fileList
        $ map (showSig' (S.fromList ["ca*", "ce", unembName]) M.empty)
              [dataDecls, assertDecls, admitDecls, elimDecls, normDecls, normCases, abbrevs]
      mapM_ (`appendFile` "\n") fileList
      where
        (dataDecls, rightAsserts) = unzip (map right is)
        assertDecls = rightAsserts ++ [left fam is]
        -- One essential case per 0-based inhabitant index, then the
        -- commutation cases; rcommRight yields Nothing for inhabitants
        -- without sub-premises, and those are skipped.
        admitDecls = zipWith (\ix _ -> essential fam is ix) [0 ..] is
                     ++ [lcommLeft fam is]
                     ++ mapMaybe rcommRight is
                     ++ [rcommLeft fam is]
        elimDecls = map elimRight is ++ [elimLeft fam is]
        (normDecls, normCases, _) = unembThm fam is
        abbrevs = [convLem fam is]

-- Prefixes '@' to the head constant of a term. It descends into the bodies
-- of Pi, lambda and arrow terms, and into the left side of applications and
-- ascriptions. Any other term is returned unchanged.
mangle :: Term -> Term
mangle term =
  case term of
    TConst ident -> TConst ('@' : ident)
    TAscribe t1 t2 -> TAscribe (mangle t1) t2
    TPi ident t1 t2 -> TPi ident t1 (mangle t2)
    TLam ident t1 t2 -> TLam ident t1 (mangle t2)
    TArrow t1 t2 -> TArrow t1 (mangle t2)
    TApp t1 t2 -> TApp (mangle t1) t2
    _ -> term

-- Name of the constant at the root of a term's conclusion (via the AST
-- helpers conc and root); fails if that root is not a constant.
concFam :: Term -> String
concFam t =
  case root (conc t) of
    TConst x -> x
    _ -> error "Could not extract type family name"

-- Builds the two declarations generated for one inhabitant of family F.
--
-- The first is the data constructor @F/<inh>. Its premises are
-- data (mangled sub-premise) for each sub-premise, followed by
-- data (@eq-<a> Xi mi) for each family argument a and conclusion index mi.
-- Its conclusion is data (@F X1 .. Xn).
--
-- The second is the right rule data+/@F/<inh>. It takes conc V (data+ DPi)
-- for each sub-premise and concludes
-- conc V (data+ (@F/<inh> DP1 .. Q1 ..)).
right :: Inhabitant -> (Decl, Decl)
right inh = (DDecl ctorName dataType, DDecl ("data+/" ++ ctorName) assertType)
  where
    family = inhFam inh
    atFam = "@" ++ famName family
    ctorName = atFam ++ "/" ++ inhName inh
    premises = inhSubs inh
    indices = inhConc inh

    -- prefix1, prefix2, ...: one variable per element of xs
    numbered prefix xs = [v (prefix ++ show i) | (i, _) <- zip defSeq xs]

    eqPremises =
      [ data' (("@eq-" ++ a) `capp` [v ("X" ++ show i), m])
      | (i, a, m) <- zip3 defSeq (famArgs family) indices
      ]
    dataType =
      (map (data' . mangle) premises ++ eqPremises)
        `arr` data' (atFam `capp` numbered "X" indices)

    dpVars = numbered "DP" premises
    assertType =
      map (concV . datap) dpVars
        `arr` concV (datap (ctorName `capp` (dpVars ++ numbered "Q" indices)))

-- Builds the left rule data+/@F/l for family F. There is one branch premise
-- per inhabitant, which is exactly the Pi-bound term returned by lcase.
-- The conclusion is
--   hyp (data+ (_ : data (@F X1 .. Xn))) -> conc V C.
left :: Fam -> [Inhabitant] -> Decl
left fam is = DDecl leftName (map (snd . lcase) is `arr` goal)
  where
    atFam = "@" ++ famName fam
    leftName = "data+/" ++ atFam ++ "/l"
    indexVars = [v ("X" ++ show i) | (i, _) <- zip defSeq (famArgs fam)]
    goal =
      hyp (datap (THole `TAscribe` data' (atFam `capp` indexVars)))
        `TArrow` concV (v "C")

-- The left-case premise for one inhabitant. It returns the binder list,
-- in this order:
--   existential names (hole-typed),
--   dpI : data (mangled sub-premise I),
--   hI  : hyp (data+ dpI),
--   qI  : data (@eq-aI XI mI) for each family argument aI / index mI,
-- paired with that list Pi-bound over conc V C.
lcase :: Inhabitant -> ([(String, Term)], Term)
lcase inh = (binders, binders `pi` concV (v "C"))
  where
    binders = concat [existentials, dataPrems, hypPrems, eqPrems]
    existentials = map (,THole) (inhExistentials inh)
    dataPrems =
      zipWith (\i t -> ("dp" ++ show i, data' (mangle t))) defSeq (inhSubs inh)
    hypPrems =
      zipWith (\i (dp, _) -> ("h" ++ show i, hyp (datap (v dp)))) defSeq dataPrems
    eqPrems =
      zipWith3
        (\i a m -> ("q" ++ show i, data' (("@eq-" ++ a) `capp` [v ("X" ++ show i), m])))
        defSeq (famArgs (inhFam inh)) (inhConc inh)

-- Left-commutation case for the left rule data+/@F/l. For each inhabitant
-- (numbered i from 1) it takes a premise that Pi-binds that inhabitant's
-- lcase environment over
--   ca* FP (SPi env..) E (SPi' env..).
-- The conclusion is
--   ca* FP (data+/@F/l SP1 .. SPn H) E (data+/@F/l SP1' .. SPn' H).
lcommLeft :: Fam -> [Inhabitant] -> Decl
lcommLeft fam is = DDecl "-" (zipWith branch positions is `arr` goal)
  where
    leftName = "data+/@" ++ famName fam ++ "/l"
    positions = zipWith const defSeq is
    spName i = "SP" ++ show i
    leftRule suffix =
      leftName `capp` ([v (spName i ++ suffix) | i <- positions] ++ [v "H"])
    goal = castar fp (leftRule "") e (leftRule "'")
    branch i inh =
      let env = fst (lcase inh)
          envVars = map (v . fst) env
      in env `pi` castar fp (v (spName i) `app` envVars) e (v (spName i ++ "'") `app` envVars)

-- Right-commutation case for the left rule data+/@F/l. For each inhabitant
-- (numbered i from 1) it takes a premise that Pi-binds the lcase environment
-- over
--   ca* FP D (\h. SPi h env..) (SPi' env..).
-- The conclusion is
--   ca* FP D (\h. data+/@F/l (SP1 h) .. (SPn h) H) (data+/@F/l SP1' .. SPn' H).
rcommLeft :: Fam -> [Inhabitant] -> Decl
rcommLeft fam is = DDecl "-" (premises `arr` goal)
  where
    leftName = "data+/@" ++ famName fam ++ "/l"
    positions = zipWith const defSeq is
    spUnderH i = TApp (v ("SP" ++ show i)) (v "h")
    spPrimed i = v ("SP" ++ show i ++ "'")
    goal =
      castar fp d
        (TLam "h" THole (leftName `capp` (map spUnderH positions ++ [v "H"])))
        (leftName `capp` (map spPrimed positions ++ [v "H"]))
    premises = zipWith premise positions is
    premise i inh =
      let env = fst (lcase inh)
          envArgs = map (v . fst) env
      in env `pi` castar fp d (TLam "h" THole (spUnderH i `app` envArgs)) (spPrimed i `app` envArgs)

-- Right-commutation case for an inhabitant's right rule data+/@F/<inh>.
-- Returns Nothing when the inhabitant has no sub-premises.
--
-- For k sub-premises, the premises are
--   ca* FP D SPi SPi'   (i = 1..k)
-- and the conclusion is
--   ca* FP D (\h. data+/@F/<inh> (SP1 h) .. (SPk h))
--            (data+/@F/<inh> SP1' .. SPk' : conc' (data+ A)),
-- where A is @F/<inh> applied to one hole per sub-premise, then
-- Qj : qc (@eq-aj Xj mj) for each family argument aj and index mj.
rcommRight :: Inhabitant -> Maybe Decl
rcommRight inh
  | null (inhSubs inh) = Nothing
  | otherwise = Just $ DDecl "-" $ subs `arr` concl
  where
    inhCtor = "@" ++ famName (inhFam inh) ++ "/" ++ inhName inh
    ctor = "data+/" ++ inhCtor
    -- 1-based position of each sub-premise
    positions = [i | (_, i) <- zip (inhSubs inh) defSeq]
    spVar i = v ("SP" ++ show i)
    spVar' i = v ("SP" ++ show i ++ "'")
    concl = castar fp d (TLam "h" THole (ctor `capp` subderivs))
                        ((ctor `capp` subderivs') `TAscribe` conc' (datap ascr))
    ascr = inhCtor `capp` (map (const THole) (inhSubs inh) ++ eqAscriptions)
    eqAscriptions =
      [ v ("Q" ++ show i) `TAscribe` qc (("@eq-" ++ tp) `capp` [v ("X" ++ show i), m])
      | (tp, m, i) <- zip3 (famArgs (inhFam inh)) (inhConc inh) defSeq
      ]
    subderivs = [spVar i `TApp` v "h" | i <- positions]
    subderivs' = map spVar' positions
    subs = [castar fp d (spVar i) (spVar' i) | i <- positions]


-- Builds the "essential" admissibility case for the inhabitant at 0-based
-- position ix in is.
--
-- The conclusion is a ca* judgement headed by the constant msre/data+.
--   * Its second argument is the right rule for this inhabitant,
--     data+/@F/<inh>, applied to SP1+ .. SPk+ (k = number of sub-premises).
--   * Its third argument abstracts h, of type
--     hyp (data+ (@F/<inh> DP1 .. DPk Q1 ..)), over the left rule
--     data+/@F/l applied to (SPj h) for every inhabitant j, then h.
--   * Its output is OUT1.
--
-- The premises form a chain of k+1 ca* steps, one per suffix of [h1 .. hk]
-- (from tails). Each step Pi-binds its suffix with hole types.
--   * Step outputs are OUT(k+1) h1..hk, OUT(k) h2..hk, ..., OUT2 hk, OUT1.
--   * Every output except the last is re-abstracted over its first
--     hypothesis (hoist) and used as the input of the next step.
--   * The first step's input applies SP(ix+1) to h, the existential holes,
--     DP1..DPk, h1..hk and the Q variables.
-- The premise list is reversed before it is turned into arrows.
-- NOTE(review): presumably so the chain is solved in step order; confirm.
--
-- Partial: evaluates is !! ix, so ix must be a valid index into is.
essential :: Fam -> [Inhabitant] -> Int -> Decl
essential fam is ix =
    DDecl "-" $ subs `arr` concl
    where
      inh = is !! ix
      fname = "@" ++ famName fam
      fctor = fname ++ "/" ++ inhName inh
      leftName = "data+/"++fname++"/l"
      rightName = "data+/"++fname++"/"++inhName inh
      concl = castar (TConst"msre/data+") lderiv rderiv (v"OUT1")
      -- One hole per existential, passed positionally to SP(ix+1).
      exiImplicits = [THole | _ <- inhExistentials inh]
      -- SP1+ .. SPk+, one per sub-premise; also the step arguments after the first.
      dataps = [v$"SP"++show i++"+" | (_,i) <- zip (inhSubs inh) defSeq]
      lderiv = rightName `capp` dataps
      rderiv = TLam "h"
               (hyp $ datap $ fctor `capp` (datas ++ eqs))
               (leftName `capp` (subderivs ++ [v"h"]))
      -- One (SPj h) per inhabitant of the family, not per sub-premise.
      subderivs = [TApp (v$"SP"++show i) (v"h") | (_,i) <- zip is defSeq]
      datas = [v$"DP"++show i | (_,i) <- zip (inhSubs inh) defSeq]
      eqs = [v$"Q"++show i | (_,i) <- zip (famArgs fam) defSeq]

      -- [[h1..hk], [h2..hk], ..., []]: k+1 suffixes, one per chain step.
      hypseqs = tails ["h"++show i | (_,i) <- zip (inhSubs inh) defSeq]
      -- [OUT(k+1) h1..hk, ..., OUT2 hk, OUT1], aligned with hypseqs.
      subOutputs = reverse [(v$"OUT"++show i) `app` (map v hyps)
                                | (hyps,i) <- zip (reverse hypseqs) defSeq]
      -- First input comes from this inhabitant's left branch SP(ix+1); each
      -- later input is the previous step's output abstracted over its first
      -- hypothesis.
      subInputs = [TLam "h" THole $
                   (v$"SP"++show (ix+1)) `app` ([v"h"]
                                                ++ exiImplicits
                                                ++ datas
                                                ++ (map v$head hypseqs)
                                                ++ eqs)]
                  ++ (map hoist $ reverse $ tail $ reverse subOutputs)

      -- Hoists the innermost argument, which must be a variable, and binds it
      hoist t = TLam (hoist' t) THole t
      hoist' (TApp (TVar _) (TVar h)) = h
      hoist' (TApp t1 _) = hoist' t1
      hoist' _ = error "Impossible"

      -- One chain step: Pi-bind the hypothesis suffix over a ca* judgement.
      mksub hyps p input output =
          (map (,THole) hyps)
          `pi` (castar (TConst"msre/data+") p input output)

      subs = reverse $ zipWith4 mksub hypseqs (lderiv:dataps) subInputs subOutputs

-- Cut-elimination case for the right rule data+/@F/<inh>. The premises are
-- ce SPi SPi' for each sub-premise. The conclusion is
--   ce (data+/@F/<inh> SP1 ..) (data+/@F/<inh> SP1' ..).
elimRight :: Inhabitant -> Decl
elimRight inh =
  DDecl "-" (premises `arr` ceOf (rule spNames) (rule (map primed spNames)))
  where
    ceOf x y = "ce" `capp` [x, y]
    primed = (++ "'")
    ctorName = "data+/@" ++ famName (inhFam inh) ++ "/" ++ inhName inh
    rule names = ctorName `capp` map v names
    spNames = ["SP" ++ show i | (i, _) <- zip defSeq (inhSubs inh)]
    premises = [ceOf (v sp) (v (primed sp)) | sp <- spNames]

-- Cut-elimination case for the left rule data+/@F/l. For each inhabitant
-- (numbered i from 1) there is a premise that Pi-binds the lcase
-- environment over
--   ce (SPi env..) (SPi' env..).
-- The conclusion is
--   ce (data+/@F/l SP1 .. H) (data+/@F/l SP1' .. H).
--
-- SPi / SPi' are built as variables (TVar) in both the premises and the
-- conclusion. Every other rule in this module does the same, so the premise
-- and conclusion occurrences refer to the same metavariable.
elimLeft :: Fam -> [Inhabitant] -> Decl
elimLeft fam is = DDecl "-" $ subs `arr` concl
  where
    leftName = "data+/@" ++ famName fam ++ "/l"
    ceOf x y = "ce" `capp` [x, y]
    args = ["SP" ++ show i | (_, i) <- zip is defSeq]
    input = leftName `capp` (map v args ++ [v "H"])
    output = leftName `capp` (map (v . (++ "'")) args ++ [v "H"])
    concl = ceOf input output
    subs =
      [ env `pi` ceOf (v arg `app` envArgs) (v (arg ++ "'") `app` envArgs)
      | (inh, arg) <- zip is args
      , let env = fst (lcase inh)
      , let envArgs = map (v . fst) env
      ]

-- Builds a per-inhabitant unembedding lemma F-unemb/<inh>: its signature,
-- mode, a single case and a worlds declaration. The totality declaration
-- is commented out, and make in this module does not call this function.
--
-- The signature takes the sub-premises, one eq-<a> Xi mi per family
-- argument, and F X1 .. Xn, and ends in type. The case applies the lemma
-- to X1 .., eq-<a>/id for each argument, and the constructor F/<inh> X1 ...
unembLem :: Inhabitant -> [Decl]
unembLem inh =
  [ DDecl lemName lemSig
  , DMode lemName modeArgs
  , DDecl "-" body
  , DWorlds [] [(lemName, map (const "_") modeArgs)]
  ]
  --,DTotal (Lexicographic []) [(lemName, ["_" | _ <- modeArgs])]]
  where
    family = inhFam inh
    fname = famName family
    ctor = fname ++ "/" ++ inhName inh
    lemName = fname ++ "-unemb/" ++ inhName inh
    xVar i = v ("X" ++ show i)
    eqSubs =
      zipWith3 (\i a m -> ("eq-" ++ a) `capp` [xVar i, m])
               defSeq (famArgs family) (inhConc inh)
    conclSig = fname `capp` [xVar i | (i, _) <- zip defSeq (famArgs family)]
    lemSig = (inhSubs inh ++ eqSubs ++ [conclSig]) `arr` TType
    modeArgs =
      [(Input, "X" ++ show i) | (i, _) <- zip defSeq (inhSubs inh ++ eqSubs)]
        ++ [(Output, "OUT")]
    xargs = [xVar i | (i, _) <- zip defSeq (inhSubs inh)]
    body =
      lemName `capp` (xargs
                      ++ [c ("eq-" ++ a ++ "/id") | a <- famArgs family]
                      ++ [ctor `capp` xargs])

-- Builds the family-level unembedding relation F-unemb, relating
-- data (@F X..) to F X... The result is split into three parts:
--   (signature and mode, one case per inhabitant, worlds and totality).
--
-- Each case:
--   * normalises and symmetrises every equality premise Qi into QPi;
--   * unembeds every sub-premise Xi into Yi;
--   * finishes with F-conv on the rebuilt constructor.
-- The premise list is reversed before it becomes arrows.
unembThm :: Fam -> [Inhabitant] -> ([Decl], [Decl], [Decl])
unembThm fam is = (header, map caseFor is, footer)
  where
    lemName = famName fam ++ "-unemb"
    modeArgs = [(Input, "IN"), (Output, "OUT")]
    wildcards = map (const "_") modeArgs
    header = [DDecl lemName lemSig, DMode lemName modeArgs]
    footer =
      [ DWorlds [] [(lemName, wildcards)]
      , DTotal (Lexicographic []) [(lemName, wildcards)]
      ]
    famSig = famName fam `capp` [v ("X" ++ show i) | (_, i) <- zip (famArgs fam) defSeq]
    lemSig = [data' (mangle famSig), famSig] `arr` TType

    -- prefix1, prefix2, ...: one variable per element of xs
    numbered prefix xs = [v (prefix ++ show i) | (i, _) <- zip defSeq xs]

    caseFor inh = DDecl "-" (reverse steps `arr` concl)
      where
        xargs = numbered "X" (inhSubs inh)
        yargs = numbered "Y" (inhSubs inh)
        qargs = numbered "Q" (famArgs fam)
        qpargs = numbered "QP" (famArgs fam)
        ctor = famName (inhFam inh) ++ "/" ++ inhName inh
        concl = lemName `capp` [mangle (ctor `capp` (xargs ++ qargs)), v "OUT"]
        eqSteps =
          concat
            [ [ ("eq-" ++ a ++ "-norm") `capp` [v ("Q" ++ show i), v ("QP" ++ show i ++ "s")]
              , ("eq-" ++ a ++ "-sym") `capp` [v ("QP" ++ show i ++ "s"), v ("QP" ++ show i)]
              ]
            | (a, i) <- zip (famArgs fam) defSeq
            ]
        subSteps =
          [ (concFam sub ++ "-unemb") `capp` [v ("X" ++ show i), v ("Y" ++ show i)]
          | (sub, i) <- zip (inhSubs inh) defSeq
          ]
        convStep =
          (famName fam ++ "-conv") `capp` ((ctor `capp` yargs) : qpargs ++ [v "OUT"])
        steps = eqSteps ++ subSteps ++ [convStep]


-- Numbered names x1 .. x<cnt>; empty when cnt < 1.
vars :: String -> Int -> [String]
vars x cnt = map ((x ++) . show) [1 .. cnt]

-- Like vars, but as variable terms.
vars' :: String -> Int -> [Term]
vars' x = map v . vars x

-- Builds the definition #F-conv for family F (DDefn with False as its first
-- argument; the flag's meaning is defined in AssertGen.AST).
--
-- Type (a): one premise qc (@eq-ti Xi Yi) per family argument ti, followed
-- by conc* (#F X1 ..) -> conc* (#F Y1 ..).
--
-- Term (m): \q1 .. qn \#sp. cut #sp (existsdl (\sp. data+/@F/l cases)).
-- Each case re-packs one inhabitant with existsdr:
--   * it rebuilds @F/<inh> from its dp variables, replacing each equality
--     argument with @eq-ti/trans (@eq-ti/sym qi) qi';
--   * it rebuilds data+/@F/<inh> from (ax hj) for every hypothesis hj.
-- NOTE(review): trans/sym presumably re-express each index relative to Yi
-- instead of Xi; confirm against the @eq-* signatures.
convLem :: Fam -> [Inhabitant] -> Decl
convLem fam is = DDefn False ("#" ++ famName fam ++ "-conv") a m
    where
      hashName = "#" ++ famName fam
      leftName = "data+/@"++famName fam ++ "/l"
      -- Type of the definition.
      a = [qc $ ("@eq-"++t) `capp` [in', out']
            | (in',out',t) <- zip3 ins outs (famArgs fam)]
          `arr` (concs (hashName `capp` ins) `TArrow` concs (hashName `capp` outs))
          where
            ins = vars' "X" $ length $ famArgs fam
            outs = vars' "Y" $ length $ famArgs fam
      -- Body of the definition.
      m = ([(x, THole) | x <- qs]
           ++ [("#sp", THole)])
          `lam` ("cut" `capp` [v"#sp", "existsdl"
                                       `capp` [TLam "sp" THole (leftName `capp` cases)]])
          where
            -- q1 .. qn: one equality binder per family argument.
            qs = vars "q" (length (famArgs fam))

            trans n t1 t2 = ("@eq-"++n++"/trans") `capp` [t1, t2]
            sym n t = ("@eq-"++n++"/sym") `capp` [t]

            -- One branch per inhabitant. The binders mirror lcase's
            -- existentials / dp / h, but the equality binders are named q1' ..
            cases = do
              inh <- is
              let ctor = "@" ++ famName fam ++ "/" ++ inhName inh
              let dataCtor = "data+/" ++ ctor
              let exis = [(var,THole) | var <- inhExistentials inh]
              let subs = [("dp"++show i,data' (mangle t))
                              | (t,i) <- zip (inhSubs inh) defSeq]
              let hyps = [("h"++show i, hyp (datap (v dp)))
                              | ((dp,_),i) <- zip subs defSeq]
              let eqs = zip (map (++"'") qs) (repeat THole)
              return $ (exis ++ subs ++ hyps ++ eqs)
                         `lam` ("existsdr" `capp`
                                [ctor `capp`
                                     (map (v.fst) subs
                                      ++[trans tp (sym tp (v q)) (v q')
                                        | (tp,q,q') <- zip3 (famArgs fam)
                                                            qs
                                                            (map (++"'") qs)])
                                ,dataCtor `capp` (map (ax.v.fst) hyps)])

-- Builds a definition named data+/@F/<inh>/l for the inhabitant at 0-based
-- position ix in is.
--
-- Its type takes two premises:
--   1. a function from qc @void to conc* C;
--   2. the inhabitant's lcase environment Pi-bound over conc* C.
-- It concludes hyp (data+ (_ : data (@F <conc terms>))) -> conc* C.
--
-- The body m is a hole (THole), so this is currently a stub; make in this
-- module does not call it.
-- Partial: evaluates is !! ix, so ix must be a valid index into is.
leftLem :: Fam -> [Inhabitant] -> Int -> Decl
leftLem fam is ix = DDefn False ("data+/@"++famName fam++"/"++inhName inh++"/l") a m
    where
      inh = is !! ix
      a = [qc (c"@void") `TArrow` (concs (v"C"))
          ,env `pi` (concs (v"C"))]
          `arr`
          (hyp
           (datap
            (THole `TAscribe` data' (atName `capp` (inhConc inh)))) `TArrow` (concs (v"C")))
      (env,_) = lcase inh
      atName = "@" ++ famName fam
      -- Placeholder body; left for a later implementation.
      m = THole